The Switch Transformer

Google Brain’s language model that switches itself on and off

Julia Turc
Towards Data Science


In the last three years, Transformer-based language models (LMs) have been stealing the show in the natural language processing (NLP) world. Transformers are usually huge networks pre-trained on massive amounts of unstructured text, capturing generally useful linguistic properties. Pre-trained models can then be fine-tuned for a myriad of end-tasks like question answering or machine translation, even on modest amounts of labeled data (see this article for the latest trends in pre-training LMs). The T5 model, Google’s record holder on multiple NLP benchmarks, was recently outranked by its own Switch Transformer.

Not all knowledge is useful all the time. This observation, somewhat obvious in hindsight, is the key to efficiency in the Switch Transformer, as it decouples inference costs from the total size of the model.

The Switch Transformer is currently getting press coverage for its one trillion parameters. While the model size is indeed eye-catching, I find its efficiency story more compelling. Given Google’s track record of operating at scale, this achievement was not hard to anticipate; the main building blocks (TensorFlow, Mesh TensorFlow, TPUs) have been in place for quite a while. What is more remarkable is that the Switch Transformer can be pre-trained significantly faster than its T5 predecessor at the same computational cost (measured as FLOPS per input token). Given the ongoing trend of building ever-larger models, it is very encouraging to see that scaling up model size doesn’t necessarily mean using more energy.

Terminology disclaimer

I will use the name Switch Transformer to refer to the newly-proposed architecture rather than to its specific 1.6-trillion-parameter instantiation. The paper scales this architecture along multiple dimensions, including the number of layers, self-attention heads, and “experts” (more on this below). Specific instantiations of the model, with a particular configuration and parameter count, have names like Switch-Base, Switch-Large, Switch-XXL and Switch-C (the latter having 1.6 trillion parameters). When contrasting the Switch Transformer against its T5 predecessor, the authors are careful to compare compatible instantiations of the two architectures. I will be less precise just to keep the prose simple, but please keep in mind that the Switch Transformer doesn’t refer only to Switch-C, its largest instantiation.

What does the transformer “switch”?

Similar to how a hardware network switch forwards an incoming packet only to the device it is intended for, the Switch Transformer routes the input signal through the model, activating only a subset of its parameters. The implicit assumption is that not all the information stored in the model is relevant for a particular input. This is somewhat intuitive — presumably the documentation for Python 3.9.1 encountered during pre-training is not terribly helpful when reading a Shakespeare play. The observation that not all knowledge is useful all the time, somewhat obvious in hindsight, is the key to the Switch Transformer’s efficiency, as it decouples inference costs from the total size of the model.

Architectural Details

This section assumes you’re familiar with the Transformer architecture. If not, you can check out one of the myriad tutorials, like the Illustrated Transformer. In a nutshell, a Transformer is a deep stack of multi-headed self-attention layers. In the standard architecture, at the end of each layer, there’s a Feed-Forward Network (FFN) meant to aggregate the outputs coming from the multiple heads. The Switch Transformer replaces this single FFN with multiple FFNs, which it calls “experts” (arguably a hyperbole). On each forward pass, at each layer, for each token in the input, the model activates exactly one expert. In other words, during a forward pass, a Switch Transformer uses roughly as many parameters as a standard Transformer with the same number of layers (plus the routing parameters, which are negligible in size).

Figure 2 from the Switch Transformers paper. Left: standard Transformer block. Right: a Switch Transformer block. The single FFN is replaced with multiple FFNs (named “experts”). Each token is passed through its designated expert.
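
To make the parameter-count argument above concrete, here is a back-of-the-envelope calculation with hypothetical layer sizes (not the ones used in the paper): the total number of expert parameters grows linearly with the number of experts, while the number of parameters touched per token stays essentially flat.

    # Back-of-the-envelope check with made-up sizes: total parameters grow with
    # the number of experts, but the parameters used per token do not.
    d_model, d_ff, num_experts = 512, 2048, 8    # hypothetical layer sizes

    ffn_params = 2 * d_model * d_ff              # two weight matrices of one FFN expert (biases ignored)
    router_params = d_model * num_experts        # the routing matrix Wr
    total = num_experts * ffn_params + router_params
    per_token = ffn_params + router_params       # exactly one expert is active per token

    print(f"total expert parameters:   {total:,}")
    print(f"parameters used per token: {per_token:,}")
    print(f"router share of the total: {router_params / total:.4%}")

With these made-up numbers, the expert pool holds roughly 16.8M parameters, but any single token only touches about 2.1M of them, and the router itself accounts for a tiny fraction of the total.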

The decision to duplicate the FFN weights rather than some other parameters of the model (like the key/query/value matrices in self-attention) seems to have been made experimentally. The authors report that their attempts to add experts to these other parts of the model resulted in training instability.

How does the model decide which expert to switch on?

Given an intermediate token embedding x (produced by the layers below), the model needs to choose an FFN expert to pass it through. The decision relies on a simple sequence of operations (a code sketch of these steps follows the list):

  1. The embedding x is multiplied by a routing matrix Wᵣ (a learnable parameter trained together with the rest of the model) to obtain a score for each expert: scores = x * Wᵣ
  2. The scores are normalized into a probability distribution, so that they sum up to 1 across experts: p = softmax(scores)
  3. The embedding x is passed through the expert i with the highest probability. Finally, the output (i.e., the updated token embedding) is the activation produced by the expert, weighted by its probability score: x’ = pᵢ * Eᵢ(x)
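
Here is a minimal PyTorch sketch of these three steps (my own illustration with hypothetical sizes and a naive per-expert loop; the actual implementation dispatches tokens across devices and also deals with expert capacity, which is omitted here):

    # Illustrative switch routing (not the authors' code): route each token to the
    # expert with the highest router probability and scale the output by that probability.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    d_model, d_ff, num_experts = 512, 2048, 8    # hypothetical sizes

    experts = nn.ModuleList([
        nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        for _ in range(num_experts)
    ])
    W_r = nn.Linear(d_model, num_experts, bias=False)   # routing matrix, trained with the model

    def switch_ffn(x):                         # x: [num_tokens, d_model]
        scores = W_r(x)                        # step 1: one score per expert
        p = F.softmax(scores, dim=-1)          # step 2: normalize into a probability distribution
        expert_idx = p.argmax(dim=-1)          # step 3: pick the highest-probability expert
        out = torch.zeros_like(x)
        for i, expert in enumerate(experts):
            mask = expert_idx == i             # tokens routed to expert i
            if mask.any():
                # the expert's output is weighted by its router probability p_i
                out[mask] = p[mask, i].unsqueeze(-1) * expert(x[mask])
        return out

    tokens = torch.randn(16, d_model)          # a toy batch of token embeddings
    print(switch_ffn(tokens).shape)            # torch.Size([16, 512])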

The contributions of the Switch Transformer

It shows that a single expert can be enough.

Granted, the underlying idea of conditional computation within a neural network (where each input activates only a subset of the parameters) is not new. Previous studies like [2], published four years prior, explored mixture-of-experts layers in the context of LSTMs: in such layers, the network selects multiple experts and aggregates their outputs. Prior work posited that a minimum of two experts was necessary for reliably training the routing parameters; the intuition was that the relevance of an expert to a given input can only be measured relative to other experts — hence the need to contrast at least two of them.

The Switch Transformer shows that selecting a single expert can be enough for training useful routing parameters, provided an additional loss term encourages uniform usage across the experts. Admittedly, gaining a deeper understanding of why this loss is so effective requires further investigation. One potential explanation (not mentioned in the paper) is that, while each token on each layer activates a single expert, a training batch consists of multiple tokens. Since the newly-added loss term encourages expert diversity, a multi-token batch performs an indirect comparison across multiple experts.
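
For intuition, here is a sketch of a load-balancing term of the kind described in [1], as I understand it: the loss is proportional to the dot product between f (the fraction of tokens actually dispatched to each expert) and P (the mean router probability assigned to each expert), scaled by the number of experts and a small coefficient. Treat the exact form and the coefficient as illustrative rather than authoritative.

    # Hedged sketch of a load-balancing auxiliary loss that encourages uniform
    # expert usage (the exact formulation and coefficient follow my reading of [1]).
    import torch
    import torch.nn.functional as F

    def load_balancing_loss(router_logits, expert_idx, num_experts, alpha=0.01):
        """router_logits: [num_tokens, num_experts]; expert_idx: [num_tokens] chosen experts."""
        p = F.softmax(router_logits, dim=-1)
        f = F.one_hot(expert_idx, num_experts).float().mean(dim=0)   # fraction of tokens per expert
        P = p.mean(dim=0)                                            # mean router probability per expert
        # the dot product f * P is smallest when tokens are spread evenly across experts
        return alpha * num_experts * torch.sum(f * P)

    logits = torch.randn(16, 8)                                      # toy router scores for 16 tokens
    print(load_balancing_loss(logits, logits.argmax(dim=-1), num_experts=8))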

Activating a single expert saves FLOPS and communication costs (since only one expert needs to send its outputs to the rest of the network).

It provides solutions for training instability.

In comparison to the traditional Transformer, the Switch Transformer needs to jump additional hurdles. First, the hard switching mechanism that turns off certain parts of the model introduces sparsity. Model sparsity is known to cause training instability — in other words, sparse models can be sensitive to the random seed (which determines the initial parameter values, the shuffling of the training data, the values to be dropped out, etc.), so different training runs can lead to different performance. To fight training instability, the authors propose reducing the scale of parameter initialization, i.e. drawing the initial values from a distribution with a smaller standard deviation (still centered around 0).
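
As an illustration (my own sketch, assuming a truncated-normal initializer; the exact scale factor is a hyperparameter chosen by the authors), reducing the initialization scale can look like this:

    # Sketch of reduced-scale initialization (illustrative; the exact factor is a hyperparameter).
    import torch

    def init_weight(shape, scale=0.1):
        """Draw weights from a truncated normal with std = sqrt(scale / fan_in).
        Shrinking `scale` (e.g. from 1.0 to 0.1) shrinks the standard deviation of
        the initial values, which helps keep the sparse model's training stable."""
        fan_in = shape[0]
        std = (scale / fan_in) ** 0.5
        w = torch.empty(shape)
        torch.nn.init.trunc_normal_(w, mean=0.0, std=std, a=-2 * std, b=2 * std)
        return w

    print(init_weight((512, 2048)).std())   # noticeably smaller than with scale=1.0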

Second, in order to reduce computational costs, the Switch Transformer uses the bfloat16 format (“Google Brain Floating Point”), in contrast to the more standard float32. Low precision is yet another cause of training instability. The authors address this by having the experts use float32 internally, while exposing a bfloat16 API to the rest of the network. Since each expert lives on a single device, the float32 values never leave the chip, and inter-device communication stays in the cheaper, lower-precision format.
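
The pattern can be sketched as follows (my own illustration of the idea, not the paper’s implementation): upcast on entry to the device-local computation, compute in float32, and downcast before anything is communicated.

    # Illustration of selective precision: numerically sensitive local computation
    # runs in float32, while tensors crossing device boundaries stay in bfloat16.
    import torch
    import torch.nn as nn

    expert = nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 512))

    def run_expert(x_bf16):
        x32 = x_bf16.to(torch.float32)       # upcast on entry; stays on this device
        y32 = expert(x32)                    # local computation in full precision
        return y32.to(torch.bfloat16)        # downcast before sending the result onward

    x = torch.randn(16, 512, dtype=torch.bfloat16)
    print(run_expert(x).dtype)               # torch.bfloat16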

It offers an interesting study of scale and efficiency

Given the ongoing trend of building ever-larger models, it is very encouraging to see that scaling up model size doesn’t necessarily mean using more energy.

By varying aspects such as the computational budget and parameter count, the paper makes the following interesting observations:

  1. Increasing the pool of experts from 1 (equivalent to the standard Transformer) to 2, 4, 8 and so on up to 256 shows a consistent increase in performance, without additional computational cost (since only one expert is activated regardless of the size of the pool).
  2. Switch Transformers achieve the same perplexity as their regular Transformer counterparts much faster. For instance, the Switch-Base model reaches the LM performance of a fully-converged T5-Base model in roughly one-seventh of the training time.

Downstream Performance

The authors show that the Switch-Base and Switch-Large instantiations exceed the performance of their T5-Base and T5-Large counterparts not only on language modelling, but also on a slew of downstream tasks like classification, coreference resolution, question answering, and summarization.

References

  1. Fedus et al.: Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity (2021)
  2. Shazeer et al.: Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer (2017)
